
import numpy as np

# Two 3x4 integer planes stacked along the leading axis -> shape (2, 3, 4).
a = np.array((
    # plane 0
    (
        (1, 5, 5, 2),
        (9, -6, 2, 8),
        (-3, 7, -9, 1),
    ),
    # plane 1
    (
        (-1, 5, -5, 2),
        (9, 6, 2, 8),
        (3, 7, 9, 1),
    ),
))

# No axis is given, so the result is an index into the flattened (row-major)
# array rather than a (plane, row, col) triple; on ties NumPy reports the
# first occurrence of the maximum.
print(np.argmax(a))



